Compression of Google Neural Machine Translation Model

Overview

Google Neural Machine Translation (GNMT) is a sequence-to-sequence (seq2seq) model that learns a mapping from an input text to an output text.

The example below demonstrates how to train a highly sparse GNMT model with minimal loss in accuracy. The model is based on the GNMT model presented in the paper Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation [1], and consists of approximately 210M floating-point parameters.

GNMT Model

The GNMT architecture is an encoder-decoder architecture with attention, as presented in the original paper [1].

The encoder consists of an embedding layer followed by 1 bi-directional and 3 uni-directional LSTM layers with residual connections between them. The decoder consists of an embedding layer followed by 4 uni-directional LSTM layers and a linear softmax layer. The attention mechanism connects the encoder’s bi-directional LSTM layer to all of the decoder’s LSTM layers.
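
For orientation only, below is a minimal Keras-style sketch of the encoder topology described above (embedding, one bi-directional LSTM, then three uni-directional LSTM layers with residual connections between the same-sized layers). It is not the GNMTModel implementation, and the layer sizes are hypothetical:

import tensorflow as tf

def gnmt_encoder_sketch(vocab_size=32000, embed_dim=1024, units=1024):
    # Illustrative encoder topology only; see GNMTModel for the actual implementation
    tokens = tf.keras.Input(shape=(None,), dtype="int32")
    x = tf.keras.layers.Embedding(vocab_size, embed_dim)(tokens)
    # 1 bi-directional LSTM layer (forward and backward outputs are concatenated)
    x = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(units, return_sequences=True))(x)
    # 3 uni-directional LSTM layers; residuals connect the same-sized layers
    x = tf.keras.layers.LSTM(units, return_sequences=True)(x)
    for _ in range(2):
        y = tf.keras.layers.LSTM(units, return_sequences=True)(x)
        x = tf.keras.layers.Add()([x, y])  # residual connection
    return tf.keras.Model(tokens, x)

encoder = gnmt_encoder_sketch()
encoder.summary()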

The GNMT model was adapted from the model shown in the Neural Machine Translation (seq2seq) Tutorial [2] and from its repository.

The sparse model implementation can be found in GNMTModel and offers several options for building the GNMT model.

Sparsity - Pruning GNMT

Sparse neural networks are networks in which a portion of the weights are zero. A high sparsity ratio helps compress the model, accelerate inference, and reduce the power consumed by memory transfers and compute.

To produce a sparse network, the network weights are pruned during training by forcing weights to zero. There are a number of methods for pruning neural networks; for example, the paper To prune, or not to prune: exploring the efficacy of pruning for model compression [3] presents a method for gradually pruning low-magnitude weights.

The example below demonstrates how to prune the GNMT model up to 90% sparsity with minimal loss in BLEU score using the TensorFlow model_pruning package, which implements the method presented in [3].
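
To make the gradual pruning idea concrete, here is a minimal NumPy sketch of the polynomial sparsity schedule from [3] (the schedule that the model_pruning package implements) combined with simple magnitude pruning. The step values below are hypothetical and do not correspond to the hyperparameters used in this example:

import numpy as np

def gradual_sparsity(step, initial_sparsity=0.0, final_sparsity=0.9,
                     begin_step=0, end_step=100000):
    # Polynomial schedule from [3]: s_t = s_f + (s_i - s_f) * (1 - progress)**3
    t = np.clip(step, begin_step, end_step)
    progress = (t - begin_step) / float(end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3

def magnitude_prune(weights, sparsity):
    # Zero out the fraction `sparsity` of weights with the smallest magnitude
    k = int(np.round(sparsity * weights.size))
    if k == 0:
        return weights
    threshold = np.sort(np.abs(weights).ravel())[k - 1]
    return np.where(np.abs(weights) <= threshold, 0.0, weights)

w = np.random.randn(1024, 1024).astype(np.float32)
s = gradual_sparsity(step=50000)      # ~0.79 halfway through this hypothetical window
w_pruned = magnitude_prune(w, s)
print(s, np.mean(w_pruned == 0.0))    # target sparsity vs. achieved sparsity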

Post Training Weight Quantization

The weights of pre-trained GNMT models are usually represented in 32-bit floating-point format. The highly sparse pre-trained model below can be further compressed by uniform quantization of the weights to 8-bit integers, gaining a further 4x compression ratio with negligible accuracy loss. The implementation of the weight quantization is based on the TensorFlow API. When using the model for inference, the int8 weights of the sparse and quantized model are de-quantized back to fp32.
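
As an illustration of this scheme only (not the exact TensorFlow-based implementation used in this example), the following NumPy sketch uniformly quantizes a weight tensor to 8-bit integers with a single per-tensor scale and de-quantizes it back to fp32 for inference:

import numpy as np

def quantize_uniform_int8(weights):
    # Symmetric uniform quantization to the int8 range [-127, 127]
    scale = np.max(np.abs(weights)) / 127.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    # Recover approximate fp32 weights before running inference
    return q.astype(np.float32) * scale

w = np.random.randn(512, 512).astype(np.float32)
q, scale = quantize_uniform_int8(w)   # stored as int8: 4x smaller than fp32
w_hat = dequantize(q, scale)          # de-quantized back to fp32 at load time
print(np.max(np.abs(w - w_hat)))      # per-weight error is bounded by ~scale/2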

Dataset

The models below were trained using the following datasets:

  • Europarl v7 [4]

  • Common Crawl Corpus

  • News Commentary 11

  • Development and test sets

All datasets are provided by the WMT Shared Task: Machine Translation of News.

You can use the wmt16_en_de.sh script to download and prepare the data for training and evaluating your model.

Results & Pre-Trained Models

The following table presents some of our experiments and results. We provide pre-trained checkpoints for a 90% sparse GNMT model and for a similar 90% sparse model with a 2x2 block sparsity pattern. See the table below and our Model Zoo. You can use these models to Run Inference using our Pre-Trained Models and evaluate them.

Model                         Sparsity   BLEU   Non-Zero Parameters   Data Type
Baseline                      0%         29.9   ~210M                 Float32
Sparse                        90%        28.4   ~22M                  Float32
2x2 Block Sparse              90%        27.8   ~22M                  Float32
Quantized Sparse              90%        28.4   ~22M                  Integer8
Quantized 2x2 Block Sparse    90%        27.6   ~22M                  Integer8

  1. The pruning is applied to the embedding, the decoder projection layer, and all LSTM layers in both the encoder and decoder.

  2. The BLEU score is measured on the newstest2015 test set provided by the Shared Task.

  3. The accuracy of the quantized model was measured by converting the 8-bit weights back to floating point during inference.

Running Modalities

Below are simple examples of training a 90% sparse GNMTModel, running inference using a pre-trained/trained model, quantizing a model to 8-bit integers, and running inference using a quantized model. Before inference, the int8 weights of the sparse and quantized model are de-quantized back to fp32.

Training

Train a German to English GNMT model with 90% sparsity using the WMT16 dataset:

# Download the dataset
wmt16_en_de.sh /tmp/wmt16_en_de

# Go to examples directory
cd <nlp_architect root>/examples

# Train the sparse GNMT
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --out_dir=<output directory> \
    --vocab_prefix=/tmp/wmt16_en_de/vocab.bpe.32000 \
    --train_prefix=/tmp/wmt16_en_de/train.tok.clean.bpe.32000 \
    --dev_prefix=/tmp/wmt16_en_de/newstest2013.tok.bpe.32000 \
    --test_prefix=/tmp/wmt16_en_de/newstest2015.tok.bpe.32000
  • Train using GPUs by adding --num_gpus=<n>

  • Model configuration JSON files are found in examples/sparse_gnmt/standard_hparams directory.

  • Sparsity policy can be re-configured by changing the parameters given in --pruning_hparams, e.g. set target_sparsity=0.7 in order to train a 70% sparse GNMT.

  • All pruning hyperparameters are listed in model_pruning.

During training, TensorFlow checkpoints, TensorBoard events, the hyperparameters used, and log files are saved in the given output directory.

Inference

Run inference using a trained model:

# Go to examples directory
cd <nlp_architect root>/examples

# Run Inference
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --ckpt=<path to a trained checkpoint> \
    --vocab_prefix=/tmp/wmt16_en_de/vocab.bpe.32000 \
    --out_dir=<output directory> \
    --inference_input_file=<file with lines in the source language> \
    --inference_output_file=<target file to place translations>
  • Measure performance and BLEU score against a reference file by adding --inference_ref_file=<reference file in the target language>

  • Run inference using GPUs by adding --num_gpus=<n>

Run Inference using our Pre-Trained Models

Run inference using our pre-trained models:

# Download pre-trained model zip file, e.g. gnmt_sparse.zip
wget https://s3-us-west-2.amazonaws.com/nlp-architect-data/models/sparse_gnmt/gnmt_sparse.zip

# Unzip checkpoint + vocabulary files
unzip gnmt_sparse.zip -d /tmp/gnmt_sparse_checkpoint

# Go to examples directory
cd <nlp_architect root>/examples

# Run Inference
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --ckpt=/tmp/gnmt_sparse_checkpoint/gnmt_sparse.ckpt \
    --vocab_prefix=/tmp/gnmt_sparse_checkpoint/vocab.bpe.32000 \
    --out_dir=<output directory> \
    --inference_input_file=<file with lines in the source language> \
    --inference_output_file=<target file to place translations>

Important Note: use the vocabulary files provided with the checkpoint when using our pre-trained models.

Quantized Inference

Add the following flags to the Inference command line in order to quantize the pre-trained models and run inference with the quantized models:

  • --quantize_ckpt=true: Produce a quantized checkpoint. Checkpoint will be saved in the output directory. Inference will run using the produced checkpoint.

  • --from_quantized_ckpt=true: Run inference using an already quantized checkpoint.

Custom Training/Inference Parameters

All customizable parameters can be obtained by running python -m sparse_gnmt.nmt -h from the examples directory:

-h, --help

show this help message and exit

--num_units NUM_UNITS

Network size.

--num_layers NUM_LAYERS

Network depth.

--num_encoder_layers NUM_ENCODER_LAYERS

Encoder depth, equal to num_layers if None.

--num_decoder_layers NUM_DECODER_LAYERS

Decoder depth, equal to num_layers if None.

--encoder_type

uni | bi | gnmt. For bi, we build num_encoder_layers/2 bi-directional layers. For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1) uni-directional layers.

--residual

Whether to add residual connections.

--time_major

Whether to use time-major mode for dynamic RNN.

--num_embeddings_partitions NUM_EMBEDDINGS_PARTITIONS

Number of partitions for embedding vars.

--attention

luong | scaled_luong | bahdanau | normed_bahdanau or set to “” for no attention

--attention_architecture

standard | gnmt | gnmt_v2. standard: use top layer to compute attention. gnmt: GNMT style of computing attention, use previous bottom layer to compute attention. gnmt_v2: similar to gnmt, but use current bottom layer to compute attention.

--output_attention

Only used in standard attention_architecture. Whether to use attention as the cell output at each timestep.

--pass_hidden_state

Whether to pass encoder’s hidden state to decoder when using an attention based model.

--optimizer

sgd | adam

--learning_rate LEARNING_RATE

Learning rate. Adam: 0.001 | 0.0001

--warmup_steps WARMUP_STEPS

How many steps we inverse-decay the learning rate.

--warmup_scheme

How to warmup learning rates. Options include: t2t: Tensor2Tensor’s way, start with lr 100 times smaller, then exponentiate until the specified lr.

--decay_scheme

How we decay learning rate. Options include: luong234: after 2/3 num train steps, we start halving the learning rate for 4 times before finishing. luong5: after 1/2 num train steps, we start halving the learning rate for 5 times before finishing. luong10: after 1/2 num train steps, we start halving the learning rate for 10 times before finishing.

--num_train_steps NUM_TRAIN_STEPS

Num steps to train.

--colocate_gradients_with_ops

Whether to try colocating gradients with the corresponding op.

--init_op

uniform | glorot_normal | glorot_uniform

--init_weight INIT_WEIGHT

For uniform init_op, initialize weights between [-init_weight, init_weight].

--src SRC

Source suffix, e.g., en.

--tgt TGT

Target suffix, e.g., de.

--train_prefix TRAIN_PREFIX

Train prefix, expect files with src/tgt suffixes.

--dev_prefix DEV_PREFIX

Dev prefix, expect files with src/tgt suffixes.

--test_prefix TEST_PREFIX

Test prefix, expect files with src/tgt suffixes.

--out_dir OUT_DIR

Store log/model files.

--vocab_prefix VOCAB_PREFIX

Vocab prefix, expect files with src/tgt suffixes.

--embed_prefix EMBED_PREFIX

Pretrained embedding prefix, expect files with src/tgt suffixes. The embedding files should be Glove formatted txt files.

--sos SOS

Start-of-sentence symbol.

--eos EOS

End-of-sentence symbol.

--share_vocab

Whether to use the source vocab and embeddings for both source and target.

--check_special_token CHECK_SPECIAL_TOKEN

Whether check special sos, eos, unk tokens exist in the vocab files.

--src_max_len SRC_MAX_LEN

Max length of src sequences during training.

--tgt_max_len TGT_MAX_LEN

Max length of tgt sequences during training.

--src_max_len_infer SRC_MAX_LEN_INFER

Max length of src sequences during inference.

--tgt_max_len_infer TGT_MAX_LEN_INFER

Max length of tgt sequences during inference. Also used to restrict the maximum decoding length.

--unit_type

lstm | gru | layer_norm_lstm | nas | mlstm

--projection_type

dense | sparse

--embedding_type

dense | sparse

--forget_bias FORGET_BIAS

Forget bias for BasicLSTMCell.

--dropout DROPOUT

Dropout rate (not keep_prob)

--max_gradient_norm MAX_GRADIENT_NORM

Clip gradients to this norm.

--batch_size BATCH_SIZE

Batch size.

--steps_per_stats STEPS_PER_STATS

How many training steps to do per stats logging. Save checkpoint every 10x steps_per_stats.

--max_train MAX_TRAIN

Limit on the size of training data (0: no limit).

--num_buckets NUM_BUCKETS

Put data into similar-length buckets.

--num_sampled_softmax NUM_SAMPLED_SOFTMAX

Use sampled_softmax_loss if > 0. Otherwise, use full softmax loss.

--subword_option

Set to bpe or spm to activate subword desegmentation.

--use_char_encode USE_CHAR_ENCODE

Whether to split each word or bpe token into characters, and then generate the word-level representation from the character representation.

--num_gpus NUM_GPUS

Number of gpus in each worker.

--log_device_placement

Debug GPU allocation.

--metrics METRICS

Comma-separated list of evaluation metrics (bleu, rouge, accuracy).

--steps_per_external_eval STEPS_PER_EXTERNAL_EVAL

How many training steps to do per external evaluation. Automatically set based on data if None.

--scope SCOPE

scope to put variables under

--hparams_path HPARAMS_PATH

Path to standard hparams json file that overrides hparams values from FLAGS.

--random_seed RANDOM_SEED

Random seed (>0, set a specific seed).

--override_loaded_hparams

Override loaded hparams with values specified

--num_keep_ckpts NUM_KEEP_CKPTS

Max number of checkpoints to keep.

--avg_ckpts

Average the last N checkpoints for external evaluation. N can be controlled by setting --num_keep_ckpts.

--language_model

True to train a language model, ignoring the encoder.

--ckpt CKPT

Checkpoint file to load a model for inference.

--quantize_ckpt QUANTIZE_CKPT

Set to True to produce a quantized checkpoint from an existing checkpoint.

--from_quantized_ckpt FROM_QUANTIZED_CKPT

Set to True when the given checkpoint is quantized

--inference_input_file INFERENCE_INPUT_FILE

Set to the text to decode.

--inference_list INFERENCE_LIST

A comma-separated list of sentence indices (0-based) to decode.

--infer_batch_size INFER_BATCH_SIZE

Batch size for inference mode.

--inference_output_file INFERENCE_OUTPUT_FILE

Output file to store decoding results.

--inference_ref_file INFERENCE_REF_FILE

Reference file to compute evaluation scores (if provided).

--infer_mode

Which type of decoder to use during inference.

--beam_width BEAM_WIDTH

beam width when using beam search decoder. If 0 (default), use standard decoder with greedy helper.

--length_penalty_weight LENGTH_PENALTY_WEIGHT

Length penalty for beam search.

--sampling_temperature SAMPLING_TEMPERATURE

Softmax sampling temperature for inference decoding, 0.0 means greedy decoding. This option is ignored when using beam search.

--num_translations_per_input NUM_TRANSLATIONS_PER_INPUT

Number of translations generated for each sentence. This is only used for inference.

--jobid JOBID

Task id of the worker.

--num_workers NUM_WORKERS

Number of workers (inference only).

--num_inter_threads NUM_INTER_THREADS

Number of inter_op_parallelism_threads.

--num_intra_threads NUM_INTRA_THREADS

Number of intra_op_parallelism_threads.

--pruning_hparams PRUNING_HPARAMS

model pruning parameters

References

[1] Wu, Yonghui, Schuster, Mike, Chen, Zhifeng, Le, Quoc V., Norouzi, Mohammad, Macherey, Wolfgang, Krikun, Maxim, Cao, Yuan, Gao, Qin, Macherey, Klaus, and others. Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation. https://arxiv.org/pdf/1609.08144.pdf

[2] Minh-Thang Luong, Eugene Brevdo, and Rui Zhao. Neural Machine Translation (seq2seq) Tutorial. https://github.com/tensorflow/nmt

[3] Zhu, Michael and Gupta, Suyog. To prune, or not to prune: exploring the efficacy of pruning for model compression. https://arxiv.org/pdf/1710.01878.pdf

[4] Koehn, Philipp. Europarl: A Parallel Corpus for Statistical Machine Translation. MT Summit 2005.